{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fe2fbcf0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2) (12.3.101)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2) (1.3.0)\n",
      "Requirement already satisfied: accelerate==0.25.0 in /opt/conda/lib/python3.10/site-packages (0.25.0)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from accelerate==0.25.0) (1.23.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from accelerate==0.25.0) (21.3)\n",
      "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate==0.25.0) (5.9.3)\n",
      "Requirement already satisfied: pyyaml in /opt/conda/lib/python3.10/site-packages (from accelerate==0.25.0) (6.0)\n",
      "Requirement already satisfied: torch>=1.10.0 in /opt/conda/lib/python3.10/site-packages (from accelerate==0.25.0) (2.2.0)\n",
      "Requirement already satisfied: huggingface-hub in /opt/conda/lib/python3.10/site-packages (from accelerate==0.25.0) (0.20.1)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from accelerate==0.25.0) (0.4.1)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->accelerate==0.25.0) (3.0.9)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch>=1.10.0->accelerate==0.25.0) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->accelerate==0.25.0) (12.3.101)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from huggingface-hub->accelerate==0.25.0) (2.28.1)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub->accelerate==0.25.0) (4.66.1)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch>=1.10.0->accelerate==0.25.0) (2.1.3)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub->accelerate==0.25.0) (2.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub->accelerate==0.25.0) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub->accelerate==0.25.0) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->huggingface-hub->accelerate==0.25.0) (2022.6.15)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch>=1.10.0->accelerate==0.25.0) (1.3.0)\n",
      "Requirement already satisfied: transformers==4.36.0 in /opt/conda/lib/python3.10/site-packages (4.36.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (3.13.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (0.20.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (1.23.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (21.3)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (6.0)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (2022.7.9)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (2.28.1)\n",
      "Requirement already satisfied: tokenizers<0.19,>=0.14 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (0.15.2)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (0.4.1)\n",
      "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (4.66.1)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.36.0) (2023.12.2)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.36.0) (4.9.0)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->transformers==4.36.0) (3.0.9)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (2.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (2022.6.15)\n",
      "Requirement already satisfied: datasets==2.3.2 in /opt/conda/lib/python3.10/site-packages (2.3.2)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (1.23.4)\n",
      "Requirement already satisfied: pyarrow>=6.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (8.0.0)\n",
      "Requirement already satisfied: dill<0.3.6 in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (0.3.5.1)\n",
      "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (1.4.3)\n",
      "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (2.28.1)\n",
      "Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (4.66.1)\n",
      "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (3.0.0)\n",
      "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (0.70.13)\n",
      "Requirement already satisfied: fsspec>=2021.05.0 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]>=2021.05.0->datasets==2.3.2) (2023.12.2)\n",
      "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (3.8.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (0.20.1)\n",
      "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (21.3)\n",
      "Requirement already satisfied: responses<0.19 in /opt/conda/lib/python3.10/site-packages (from datasets==2.3.2) (0.18.0)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.3.2) (23.2.0)\n",
      "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.3.2) (2.1.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.3.2) (6.0.2)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.3.2) (4.0.2)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.3.2) (1.7.2)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.3.2) (1.3.0)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets==2.3.2) (1.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets==2.3.2) (3.13.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets==2.3.2) (6.0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets==2.3.2) (4.9.0)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging->datasets==2.3.2) (3.0.9)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets==2.3.2) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets==2.3.2) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.19.0->datasets==2.3.2) (2022.6.15)\n",
      "Requirement already satisfied: python-dateutil>=2.8.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets==2.3.2) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets==2.3.2) (2022.1)\n",
      "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas->datasets==2.3.2) (1.16.0)\n",
      "Requirement already satisfied: numpy==1.23.4 in /opt/conda/lib/python3.10/site-packages (1.23.4)\n"
     ]
    }
   ],
   "source": [
    "!pip install torch==2.2\n",
    "!pip install accelerate==0.25.0\n",
    "!pip install transformers==4.36.0\n",
    "!pip install datasets==2.3.2\n",
    "!pip install numpy==1.23.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "62d977e0-8760-46bb-b08f-b335fbeb5a79",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n",
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n",
      "Using custom data configuration default\n",
      "Reusing dataset rotten_tomatoes (/home/ashish.jha/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3358322d2b0c41ca9058fba51b9b9fd3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Parameter 'function'=<function tokenize_function at 0x7fbc826b0040> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "75b23b72a9854a4fb99fe68d79e66cb1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/9 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b4f985d0eac4fada1bfeeb0a90a16fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a41bbf4d64d4c9cb81463096ca6d3a1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import BertTokenizer\n",
    "from accelerate import Accelerator\n",
    "\n",
    "# Initialize the accelerator\n",
    "accelerator = Accelerator(cpu=False, mixed_precision=\"fp16\")\n",
    "\n",
    "# Loading a dataset from HuggingFace Datasets library\n",
    "dataset = load_dataset(\"rotten_tomatoes\")\n",
    "\n",
    "# Initializing a tokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "# Tokenizing and preparing the dataset for PyTorch\n",
    "def tokenize_function(example):\n",
    "    return tokenizer(example[\"text\"], padding=\"max_length\", truncation=True)\n",
    "\n",
    "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
    "tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])\n",
    "\n",
    "# Creating PyTorch DataLoader\n",
    "train_dataloader = torch.utils.data.DataLoader(tokenized_dataset['train'], batch_size=8, shuffle=True)\n",
    "eval_dataloader = torch.utils.data.DataLoader(tokenized_dataset['test'], batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7de76a32-8fe4-43e0-9c17-f9b4ca259af1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/opt/conda/lib/python3.10/site-packages/transformers/optimization.py:429: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "  0%|                                                                                                                                                                                                                                                                      | 0/1067 [00:00<?, ?it/s]2024-02-16 16:00:57.026641: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-02-16 16:00:57.026707: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-02-16 16:00:57.028280: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1067/1067 [05:14<00:00,  3.39it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:11<00:00, 11.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 - Evaluation Accuracy: 0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1067/1067 [05:07<00:00,  3.47it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:12<00:00, 11.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 - Evaluation Accuracy: 0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1067/1067 [05:07<00:00,  3.47it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134/134 [00:11<00:00, 11.28it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 - Evaluation Accuracy: 0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from transformers import BertForSequenceClassification, AdamW\n",
    "\n",
    "# Load pre-trained BERT model for sequence classification\n",
    "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "# Move model to the device managed by Accelerator\n",
    "model, train_dataloader, eval_dataloader = accelerator.prepare(model, train_dataloader, eval_dataloader)\n",
    "\n",
    "# Optimizer and learning rate scheduler setup\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "# Training loop using PyTorch\n",
    "for epoch in range(3):  # Train for 3 epochs as an example\n",
    "    model.train()\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        optimizer.zero_grad()\n",
    "        input_ids = batch['input_ids']\n",
    "        attention_mask = batch['attention_mask']\n",
    "        labels = batch['label']\n",
    "\n",
    "        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "        optimizer.step()\n",
    "\n",
    "    # Evaluation loop\n",
    "    model.eval()\n",
    "    total_correct = 0\n",
    "    total_samples = 0\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        with torch.no_grad():\n",
    "            input_ids = batch['input_ids']\n",
    "            attention_mask = batch['attention_mask']\n",
    "            labels = batch['label']\n",
    "\n",
    "            outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "            predictions = torch.argmax(outputs.logits, dim=1)\n",
    "\n",
    "            total_correct += accelerator.gather(predictions.eq(labels)).sum().item()\n",
    "            total_samples += len(labels)\n",
    "\n",
    "    accuracy = total_correct / total_samples\n",
    "    print(f\"Epoch {epoch + 1} - Evaluation Accuracy: {accuracy}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c7ad201-d253-4c52-81b8-18885ea8c782",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bf97faf-eb1a-4c12-a97b-f107f53f05cf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
